{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/finetuning/quickstart_peft_finetuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PEFT Finetuning Quick Start Notebook\n",
    "\n",
    "This notebook shows how to train a Meta Llama 3 model on a single GPU (e.g. A10 with 24GB) using int8 quantization and LoRA finetuning.\n",
    "\n",
    "**_Note:_** To run this notebook on a machine with less than 24GB VRAM (e.g. T4 with 16GB) the context length of the training dataset needs to be adapted.\n",
    "We do this based on the available VRAM during execution.\n",
    "If you run into OOM issues try to further lower the value of train_config.context_length."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 0: Install pre-requirements and convert checkpoint\n",
    "\n",
    "We need to have llama-recipes and its dependencies installed for this notebook. Additionally, we need to log in with the huggingface_cli and make sure that the account is able to to access the Meta Llama weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# uncomment if running from Colab T4\n",
    "# ! pip install llama-recipes ipywidgets\n",
    "\n",
    "# import huggingface_hub\n",
    "# huggingface_hub.login()"
   ]
  },
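  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, we can confirm that a CUDA GPU is visible and see how much VRAM it offers, since the training configuration below adapts the context length to the available memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Optional sanity check: this notebook expects a single CUDA GPU\n",
    "assert torch.cuda.is_available(), \"No CUDA GPU found; a GPU such as a T4 or A10 is required\"\n",
    "props = torch.cuda.get_device_properties(0)\n",
    "print(f\"GPU: {props.name}, VRAM: {props.total_memory / 1e9:.1f} GB\")"
   ]
  },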
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Load the model\n",
    "\n",
    "Setup training configuration and load the model and tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "68838a4f42f84545912e95b339a31034",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import LlamaForCausalLM, AutoTokenizer\n",
    "from llama_recipes.configs import train_config as TRAIN_CONFIG\n",
    "\n",
    "train_config = TRAIN_CONFIG()\n",
    "train_config.model_name = \"meta-llama/Meta-Llama-3.1-8B\"\n",
    "train_config.num_epochs = 1\n",
    "train_config.run_validation = False\n",
    "train_config.gradient_accumulation_steps = 4\n",
    "train_config.batch_size_training = 1\n",
    "train_config.lr = 3e-4\n",
    "train_config.use_fast_kernels = True\n",
    "train_config.use_fp16 = True\n",
    "train_config.context_length = 1024 if torch.cuda.get_device_properties(0).total_memory < 16e9 else 2048 # T4 16GB or A10 24GB\n",
    "train_config.batching_strategy = \"packing\"\n",
    "train_config.output_dir = \"meta-llama-samsum\"\n",
    "train_config.use_peft = True\n",
    "\n",
    "from transformers import BitsAndBytesConfig\n",
    "config = BitsAndBytesConfig(\n",
    "    load_in_8bit=True,\n",
    ")\n",
    "\n",
    "model = LlamaForCausalLM.from_pretrained(\n",
    "            train_config.model_name,\n",
    "            device_map=\"auto\",\n",
    "            quantization_config=config,\n",
    "            use_cache=False,\n",
    "            attn_implementation=\"sdpa\" if train_config.use_fast_kernels else None,\n",
    "            torch_dtype=torch.float16,\n",
    "        )\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
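  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what int8 quantization buys us, we can query the model's memory footprint. As a rough check, `get_memory_footprint` sums the size of the model's parameters and buffers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough check of how much memory the int8-quantized weights occupy\n",
    "print(f\"Model memory footprint: {model.get_memory_footprint() / 1e9:.1f} GB\")"
   ]
  },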
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Check base model\n",
    "\n",
    "Run the base model on an example input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summarize this dialog:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
      "A: I'll get him one of those little dogs.\n",
      "B: One that won't grow up too big;-)\n",
      "A: And eat too much;-))\n",
      "B: Do you know which one he would like?\n",
      "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
      "B: I bet you had to drag him away.\n",
      "A: He wanted to take it home right away ;-).\n",
      "B: I wonder what he'll name it.\n",
      "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
      "---\n",
      "Summary:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue\n"
     ]
    }
   ],
   "source": [
    "eval_prompt = \"\"\"\n",
    "Summarize this dialog:\n",
    "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
    "B: I’m pretty sure I am. What’s up?\n",
    "A: Can you go with me to the animal shelter?.\n",
    "B: What do you want to do?\n",
    "A: I want to get a puppy for my son.\n",
    "B: That will make him so happy.\n",
    "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
    "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
    "A: I'll get him one of those little dogs.\n",
    "B: One that won't grow up too big;-)\n",
    "A: And eat too much;-))\n",
    "B: Do you know which one he would like?\n",
    "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
    "B: I bet you had to drag him away.\n",
    "A: He wanted to take it home right away ;-).\n",
    "B: I wonder what he'll name it.\n",
    "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
    "---\n",
    "Summary:\n",
    "\"\"\"\n",
    "\n",
    "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the base model only repeats the conversation.\n",
    "\n",
    "### Step 3: Load the preprocessed dataset\n",
    "\n",
    "We load and preprocess the samsum dataset which consists of curated pairs of dialogs and their summarization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead\n",
      "  from torch.distributed._shard.checkpoint import (\n",
      "Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 5872.02it/s]\n"
     ]
    }
   ],
   "source": [
    "from llama_recipes.configs.datasets import samsum_dataset\n",
    "from llama_recipes.utils.dataset_utils import get_dataloader\n",
    "\n",
    "samsum_dataset.trust_remote_code = True\n",
    "\n",
    "train_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config)\n",
    "eval_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config, \"val\")"
   ]
  },
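  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check that the packing strategy produced what we expect, we can peek at a single batch. This is a minimal sketch, assuming the dataloader yields a dict of tensors (the usual causal-LM layout); with packing, every sequence should be exactly `train_config.context_length` tokens long:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one batch; with the packing strategy every sequence\n",
    "# should have length train_config.context_length\n",
    "batch = next(iter(train_dataloader))\n",
    "for key, value in batch.items():\n",
    "    print(f\"{key}: {tuple(value.shape)}\")"
   ]
  },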
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Prepare model for PEFT\n",
    "\n",
    "Let's prepare the model for Parameter Efficient Fine Tuning (PEFT):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig\n",
    "from dataclasses import asdict\n",
    "from llama_recipes.configs import lora_config as LORA_CONFIG\n",
    "\n",
    "lora_config = LORA_CONFIG()\n",
    "lora_config.r = 8\n",
    "lora_config.lora_alpha = 32\n",
    "lora_dropout: float=0.01\n",
    "\n",
    "peft_config = LoraConfig(**asdict(lora_config))\n",
    "\n",
    "model = prepare_model_for_kbit_training(model)\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
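  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the LoRA adapters attached, only a small fraction of the weights remains trainable; the 8-bit base weights stay frozen. PEFT can report this directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the LoRA adapter matrices are trained; the quantized base weights are frozen\n",
    "model.print_trainable_parameters()"
   ]
  },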
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Fine tune the model\n",
    "\n",
    "Here, we fine tune the model for a single epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:92: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
      "  scaler = torch.cuda.amp.GradScaler()\n",
      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
      "  warnings.warn(\n",
      "Training Epoch: 1:   0%|\u001b[34m          \u001b[0m| 0/319 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:151: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  with autocast():\n",
      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
      "  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]\n",
      "Training Epoch: 1/1, step 1278/1279 completed (loss: 0.28094857931137085): : 320it [2:08:50, 24.16s/it]                      4.21s/it]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max CUDA memory allocated was 15 GB\n",
      "Max CUDA memory reserved was 16 GB\n",
      "Peak active CUDA memory was 15 GB\n",
      "CUDA Malloc retries : 0\n",
      "CPU Total Peak Memory consumed during the train (max): 2 GB\n",
      "Epoch 1: train_perplexity=1.3404, train_epoch_loss=0.2930, epoch time 7730.981359725998s\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "from llama_recipes.utils.train_utils import train\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "model.train()\n",
    "\n",
    "optimizer = optim.AdamW(\n",
    "            model.parameters(),\n",
    "            lr=train_config.lr,\n",
    "            weight_decay=train_config.weight_decay,\n",
    "        )\n",
    "scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)\n",
    "\n",
    "# Start the training process\n",
    "results = train(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    eval_dataloader,\n",
    "    tokenizer,\n",
    "    optimizer,\n",
    "    scheduler,\n",
    "    train_config.gradient_accumulation_steps,\n",
    "    train_config,\n",
    "    None,\n",
    "    None,\n",
    "    None,\n",
    "    wandb_run=None,\n",
    ")"
   ]
  },
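  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train` returns a dictionary of metrics collected during the run (the exact keys depend on the llama-recipes version), which we can inspect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the metrics returned by the training loop\n",
    "for metric, value in results.items():\n",
    "    print(f\"{metric}: {value}\")"
   ]
  },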
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 6:\n",
    "Save model checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(train_config.output_dir)"
   ]
  },
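  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `save_pretrained` on a PEFT model stores only the small LoRA adapter weights, not the full base model. As a sketch of how the checkpoint would be used in a fresh session, the adapter can be re-attached to a newly loaded base model (assuming the same quantization setup as in Step 1):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: reload the saved adapter in a fresh session.\n",
    "# Not needed here, since `model` already carries the trained adapter.\n",
    "# from peft import PeftModel\n",
    "# base_model = LlamaForCausalLM.from_pretrained(\n",
    "#     train_config.model_name,\n",
    "#     device_map=\"auto\",\n",
    "#     quantization_config=config,\n",
    "#     torch_dtype=torch.float16,\n",
    "# )\n",
    "# model = PeftModel.from_pretrained(base_model, train_config.output_dir)"
   ]
  },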
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 7:\n",
    "Try the fine tuned model on the same example again to see the learning progress:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summarize this dialog:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
      "A: I'll get him one of those little dogs.\n",
      "B: One that won't grow up too big;-)\n",
      "A: And eat too much;-))\n",
      "B: Do you know which one he would like?\n",
      "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
      "B: I bet you had to drag him away.\n",
      "A: He wanted to take it home right away ;-).\n",
      "B: I wonder what he'll name it.\n",
      "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
      "---\n",
      "Summary:\n",
      "A wants to get a puppy for his son. A took him to the animal shelter last Monday and he showed A one he really liked. A wants to get him one of those little dogs. A and B agree that raising a dog is a tough issue.\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
